{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import kornia as tgm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input Parameters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# locations of the two images to align\n",
    "input_dir = 'data'\n",
    "src_name = 'img1.ppm'  # source image file\n",
    "dst_name = 'img2.ppm'  # destination image file\n",
    "\n",
    "# optimization settings\n",
    "learning_rate = 1e-3\n",
    "num_iterations = 400\n",
    "log_interval = 100  # log (and plot) every `log_interval` iterations\n",
    "\n",
    "# run on the GPU when one is available, otherwise on the CPU\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "print('Using ', device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image(file_name):\n",
    "    \"\"\"Load an image with OpenCV and convert it to a torch.Tensor.\n",
    "\n",
    "    Args:\n",
    "        file_name (str): path of the image file to read.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(tensor, img)`` where ``tensor`` is the float tensor of\n",
    "        shape 1xCxHxW divided by 255 and ``img`` is the array returned by\n",
    "        ``cv2.imread`` (channels stay in the order OpenCV loads them).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if ``file_name`` is not an existing file.\n",
    "        IOError: if OpenCV cannot decode the file.\n",
    "    \"\"\"\n",
    "    assert os.path.isfile(file_name), \"Invalid file {}\".format(file_name)\n",
    "\n",
    "    # cv2.imread reports a decoding failure by returning None, not by raising\n",
    "    img = cv2.imread(file_name, cv2.IMREAD_COLOR)\n",
    "    if img is None:\n",
    "        raise IOError('Could not decode image {}'.format(file_name))\n",
    "\n",
    "    # convert image to torch tensor and scale by 1/255\n",
    "    tensor = tgm.utils.image_to_tensor(img).float() / 255.\n",
    "    tensor = tensor.unsqueeze(0)  # add the batch axis: 1xCxHxW\n",
    "    return tensor, img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Container to hold the homography as a trainable parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyHomography(nn.Module):\n",
    "    \"\"\"Module whose only parameter is a trainable 3x3 homography.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # allocated uninitialized; reset_parameters() fills in the values\n",
    "        self.homo = nn.Parameter(torch.Tensor(3, 3))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Set the homography to the identity matrix.\"\"\"\n",
    "        nn.init.eye_(self.homo)\n",
    "\n",
    "    def forward(self):\n",
    "        \"\"\"Return the homography with a leading batch axis (1x3x3).\"\"\"\n",
    "        return self.homo.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read the images and convert them to tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# each call returns the (1xCxHxW tensor, OpenCV array) pair for one image\n",
    "img_src_t, img_src = load_image(os.path.join(input_dir, src_name))\n",
    "img_dst_t, img_dst = load_image(os.path.join(input_dir, dst_name))\n",
    "\n",
    "# show both images side by side on a 15x15 inch figure\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)\n",
    "fig.set_size_inches(15, 15)\n",
    "for panel_ax, panel_title, panel_img in (\n",
    "        (ax1, 'Source image', img_src),\n",
    "        (ax2, 'Destination image', img_dst)):\n",
    "    # reverse the channel axis of the OpenCV array before displaying it\n",
    "    panel_ax.imshow(panel_img[..., ::-1])\n",
    "    panel_ax.set_title(panel_title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the warper and the homography"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the warper is built for the spatial size (last two axes) of the source tensor\n",
    "height, width = img_src_t.size(-2), img_src_t.size(-1)\n",
    "warper = tgm.HomographyWarper(height, width)\n",
    "\n",
    "# trainable homography and the optimizer that updates it\n",
    "dst_homo_src = MyHomography().to(device)\n",
    "optimizer = optim.Adam(dst_homo_src.parameters(), lr=learning_rate)\n",
    "\n",
    "# send data to device\n",
    "img_src_t = img_src_t.to(device)\n",
    "img_dst_t = img_dst_t.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_rectangle(image, dst_homo_src):\n",
    "    \"\"\"Draw the outline of the warped image frame on ``image``.\n",
    "\n",
    "    The four corners of the normalized frame ([-1, 1] x [-1, 1]) are mapped\n",
    "    through the inverse of ``dst_homo_src``, converted to pixel coordinates\n",
    "    and joined with four line segments.\n",
    "\n",
    "    Args:\n",
    "        image (numpy.ndarray): image to draw on; its first two axes are\n",
    "            read as (height, width).\n",
    "        dst_homo_src (torch.Tensor): batched homography of shape 1x3x3.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: the image with all four edges drawn.\n",
    "    \"\"\"\n",
    "    height, width = image.shape[:2]\n",
    "    # corners as (x, y) in normalized coordinates; y = -1 maps to pixel row 0\n",
    "    pts_src = torch.tensor([[\n",
    "        [-1., -1.],  # top-left\n",
    "        [1., -1.],  # top-right\n",
    "        [1., 1.],  # bottom-right\n",
    "        [-1., 1.],  # bottom-left\n",
    "    ]], dtype=dst_homo_src.dtype, device=dst_homo_src.device)\n",
    "    # transform points\n",
    "    pts_dst = tgm.transform_points(torch.inverse(dst_homo_src), pts_src)\n",
    "    pts_dst = pts_dst.detach().cpu().numpy().reshape(4, 2)\n",
    "\n",
    "    # normalized [-1, 1] -> pixel [0, size - 1], rounded to whole pixels\n",
    "    half_extent = np.array([width - 1, height - 1], dtype=pts_dst.dtype) / 2.0\n",
    "    pts_px = np.rint((pts_dst + 1.0) * half_extent).astype(int)\n",
    "\n",
    "    # guard against a strided view (e.g. a transposed array) before drawing\n",
    "    image = np.ascontiguousarray(image)\n",
    "    for i in range(4):\n",
    "        # cv2.line takes integer pixel coordinates, so cast numpy scalars\n",
    "        pt_i = tuple(int(v) for v in pts_px[i])\n",
    "        pt_ii = tuple(int(v) for v in pts_px[(i + 1) % 4])\n",
    "        image = cv2.line(image, pt_i, pt_ii, (255, 0, 0), 3)\n",
    "    # return only after all four edges have been drawn\n",
    "    return image"
   ]
  },
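  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Main optimization loop\n",
    "\n",
    "This is the loss function to minimize the photometric error:\n",
    "\n",
    "$L = \\sum |I_{dst} - \\omega(I_{ref}, H_{ref}^{dst})|$\n",
    "\n",
    "where $\\omega$ warps the reference (source) image $I_{ref}$ with the homography $H_{ref}^{dst}$ being optimized, and $I_{dst}$ is the destination image."
   ]
  },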
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loop-invariant bounds of the central window used for the loss; the lower\n",
    "# bounds are clamped at 0 so that an image smaller than the window does not\n",
    "# produce negative (wrap-around) slice indices\n",
    "w_size = 100  # half-size of the window, in pixels\n",
    "h_2, w_2 = height // 2, width // 2\n",
    "row_0, row_1 = max(h_2 - w_size, 0), h_2 + w_size\n",
    "col_0, col_1 = max(w_2 - w_size, 0), w_2 + w_size\n",
    "\n",
    "for iter_idx in range(num_iterations):\n",
    "    # warp the reference image to the destination with the current homography\n",
    "    img_src_to_dst = warper(img_src_t, dst_homo_src())\n",
    "\n",
    "    # compute the per-pixel photometric (L1) error\n",
    "    loss = F.l1_loss(img_src_to_dst, img_dst_t, reduction='none')\n",
    "\n",
    "    # propagate the error just for the fixed central window\n",
    "    loss = loss[..., row_0:row_1, col_0:col_1].mean()\n",
    "\n",
    "    # compute gradient and update optimizer parameters\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if iter_idx % log_interval == 0 or iter_idx == num_iterations - 1:\n",
    "        print('Train iteration: {}/{}\\tLoss: {:.6}'.format(\n",
    "            iter_idx, num_iterations, loss.item()))\n",
    "        # visualization only, so no autograd graph is needed\n",
    "        with torch.no_grad():\n",
    "            # merge warped and target image for visualization\n",
    "            img_src_to_dst = warper(img_src_t, dst_homo_src())\n",
    "            img_vis = 255. * 0.5 * (img_src_to_dst + img_dst_t)\n",
    "            img_vis_np = tgm.utils.tensor_to_image(img_vis[0, ...])\n",
    "            image_draw = draw_rectangle(img_vis_np, dst_homo_src())\n",
    "        # show as 8-bit values with the channel axis reversed for matplotlib\n",
    "        plt.imshow(np.clip(image_draw, 0, 255).astype(np.uint8)[:, :, ::-1])\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
